# Copyright (c) Company Co., Ltd.
# All rights reserved.
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from cbml_benchmark.losses.registry import LOSS


@LOSS.register('ms_trans_loss')
class MultiSimilarityTransLoss(nn.Module):
    """Match the pairwise cosine-distance structure of two embedding sets.

    Both inputs are L2-normalized row-wise. The loss is the mean absolute
    difference between the ``1 - cos`` distance matrix of ``source`` and the
    ``delta``-scaled ``1 - cos`` distance matrix of ``feats``. The mean is
    taken over all ``N x N`` pairs, including the diagonal, where the term is
    always zero.

    Args:
        cfg: configuration object supplied by the loss registry; not read here.
        temperature: softmax temperature used by :meth:`kl_div`. Stored as
            ``self.T`` for compatibility. Defaults to 10.
        delta: scale applied to the ``feats`` distance matrix before
            comparison. Defaults to 1.06.
    """

    def __init__(self, cfg, temperature=10, delta=1.06):
        super().__init__()
        self.T = temperature
        self.delta = delta

    def kl_div(self, A, B):
        """Return temperature-softened KL(softmax(B/T) || softmax(A/T)).

        Softmax is taken over the last dimension. The summed divergence is
        divided by ``A.shape[0]``, which is equivalent to ``'batchmean'``.
        It is also multiplied by ``T**2``, the usual distillation rescaling
        that keeps gradient magnitudes comparable across temperatures.
        Not used by :meth:`forward`.
        """
        log_p_A = F.log_softmax(A / self.T, dim=-1)
        p_B = F.softmax(B / self.T, dim=-1)
        return F.kl_div(log_p_A, p_B, reduction='sum') * (self.T ** 2) / A.shape[0]

    def forward(self, feats, source):
        """Compute the distance-matching loss.

        Args:
            feats: 2-D tensor ``(N, D_f)``; its distances are scaled by
                ``self.delta``.
            source: 2-D tensor ``(N, D_s)`` with the same ``N`` as ``feats``.
                The feature dimensions may differ.

        Returns:
            Scalar tensor: mean of ``|(1 - cos_s) - delta * (1 - cos_t)|``.
        """
        # NOTE(review): gradients flow into ``source`` as well; detach it at
        # the call site if it is meant to be a fixed target — confirm.
        source = F.normalize(source, p=2, dim=1)
        feats = F.normalize(feats, p=2, dim=1)

        dist_s = 1. - torch.matmul(source, torch.t(source))
        dist_t = 1. - torch.matmul(feats, torch.t(feats))

        return torch.abs(dist_s - self.delta * dist_t).mean()
